# Copyright (c) 2009 IW.
# All rights reserved.
#
# Author: liuguiyang <liuguiyangnwpu@gmail.com>
# Date:   2017/10/17

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from matplotlib import pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D


# Result table for the SSD model (presumably detector evaluation output; how
# these numbers were produced is not visible in this file).
# Each key -- only "plane" here, presumably an object class -- maps to a list
# of [iou, prob, precision, recall] rows. iou and prob (presumably the IoU
# matching threshold and the score threshold) each run from 0.0 to 1.0 in
# steps of 0.125, giving 9 x 9 = 81 rows, ordered with iou as the outer
# (slow-varying) coordinate and prob as the inner one.
# Rows where either threshold is 1.0 store 0 for precision and recall,
# apparently as placeholders -- TODO confirm.
ssd_res_details_dict = dict(
    plane=[[0.0, 0.0, 0.019157553419517023, 1.0],
           [0.0, 0.125, 0.6798647469458988, 0.9960051134547779],
           [0.0, 0.25, 0.8265070077561573, 0.9771557271557272],
           [0.0, 0.375, 0.9122972760195245, 0.9430338541666666],
           [0.0, 0.5, 0.9489204844655081, 0.8959231024196221],
           [0.0, 0.625, 0.9621342512908778, 0.8604412519240636],
           [0.0, 0.75, 0.9693640397211071, 0.8197248525996069],
           [0.0, 0.875, 0.9820438341695273, 0.7286442006269592],
           [0.0, 1.0, 0, 0],
           [0.125, 0.0, 0.01907557973610546, 0.9957210776545167],
           [0.125, 0.125, 0.6539048865619547, 0.9579737935442634],
           [0.125, 0.25, 0.7811947203701184, 0.9235842985842986],
           [0.125, 0.375, 0.8560856558022358, 0.8849283854166666],
           [0.125, 0.5, 0.8995962787431981, 0.8493536625787206],
           [0.125, 0.625, 0.9150889271371199, 0.8183683940482298],
           [0.125, 0.75, 0.9205577857595605, 0.7784527425406468],
           [0.125, 0.875, 0.9405862160021125, 0.697884012539185],
           [0.125, 1.0, 0, 0],
           [0.25, 0.0, 0.018963245429208134, 0.9898573692551506],
           [0.25, 0.125, 0.6533595113438045, 0.9571748162352189],
           [0.25, 0.25, 0.7805143556946523, 0.9227799227799228],
           [0.25, 0.375, 0.8552983782081562, 0.8841145833333334],
           [0.25, 0.5, 0.8987186238371072, 0.8485250248591316],
           [0.25, 0.625, 0.914132721361637, 0.8175132546605096],
           [0.25, 0.75, 0.9197126558208325, 0.7777380739681973],
           [0.25, 0.875, 0.9405862160021125, 0.697884012539185],
           [0.25, 1.0, 0, 0],
           [0.375, 0.0, 0.018532124575710286, 0.9673534072900158],
           [0.375, 0.125, 0.6411431064572426, 0.9392777245126238],
           [0.375, 0.25, 0.7685399374064499, 0.9086229086229086],
           [0.375, 0.375, 0.8442764918910408, 0.8727213541666666],
           [0.375, 0.5, 0.8901176057574162, 0.8404043752071594],
           [0.375, 0.625, 0.9082042455536431, 0.8122113904566445],
           [0.375, 0.75, 0.9156982886118741, 0.774343398249062],
           [0.375, 0.875, 0.9392659096910483, 0.6969043887147336],
           [0.375, 1.0, 0, 0],
           [0.5, 0.0, 0.016236861440186535, 0.8475435816164818],
           [0.5, 0.125, 0.5458115183246073, 0.7996164908916586],
           [0.5, 0.25, 0.6547829636685263, 0.7741312741312741],
           [0.5, 0.375, 0.723665564478035, 0.748046875],
           [0.5, 0.5, 0.7691767596980867, 0.7262180974477959],
           [0.5, 0.625, 0.7936507936507936, 0.7097656918077647],
           [0.5, 0.75, 0.8115360236636383, 0.6862604966946578],
           [0.5, 0.875, 0.8581991021917085, 0.6367554858934169],
           [0.5, 1.0, 0, 0],
           [0.625, 0.0, 0.010453162666148512, 0.5456418383518225],
           [0.625, 0.125, 0.3430410122164049, 0.5025567273889422],
           [0.625, 0.25, 0.4159749625799429, 0.4917953667953668],
           [0.625, 0.375, 0.4682727129585892, 0.4840494791666667],
           [0.625, 0.5, 0.5093909074951729, 0.4809413324494531],
           [0.625, 0.625, 0.5385350927519602, 0.4816145031640157],
           [0.625, 0.75, 0.5704627086414537, 0.4824012864034304],
           [0.625, 0.875, 0.6395563770794824, 0.4745297805642633],
           [0.625, 1.0, 0, 0],
           [0.75, 0.0, 0.003971169551937918, 0.20729001584786053],
           [0.75, 0.125, 0.13350785340314136, 0.1955896452540748],
           [0.75, 0.25, 0.16478432439787727, 0.19481981981981983],
           [0.75, 0.375, 0.18926153361675327, 0.19563802083333334],
           [0.75, 0.5, 0.2083552747059856, 0.19671859463042757],
           [0.75, 0.625, 0.219162363740677, 0.19599794766546946],
           [0.75, 0.75, 0.23663638284386224, 0.20010720028586743],
           [0.75, 0.875, 0.26643781357274887, 0.19768808777429467],
           [0.75, 1.0, 0, 0],
           [0.875, 0.0, 0.00031878654660052097, 0.01664025356576862],
           [0.875, 0.125, 0.010798429319371727, 0.015819750719079578],
           [0.875, 0.25, 0.013471220574227786, 0.015926640926640926],
           [0.875, 0.375, 0.01543064084396158, 0.015950520833333332],
           [0.875, 0.5, 0.01720203615938213, 0.016241299303944315],
           [0.875, 0.625, 0.018167909734174793, 0.01624764836668377],
           [0.875, 0.75, 0.019015423621381786, 0.016080042880114345],
           [0.875, 0.875, 0.02033271719038817, 0.015086206896551725],
           [0.875, 1.0, 0, 0],
           [1.0, 0.0, 0, 0],
           [1.0, 0.125, 0, 0],
           [1.0, 0.25, 0, 0],
           [1.0, 0.375, 0, 0],
           [1.0, 0.5, 0, 0],
           [1.0, 0.625, 0, 0],
           [1.0, 0.75, 0, 0],
           [1.0, 0.875, 0, 0],
           [1.0, 1.0, 0, 0]])


# Result table for the DSSD model (presumably detector evaluation output; how
# these numbers were produced is not visible in this file).
# Each key -- only "plane" here, presumably an object class -- maps to a list
# of [iou, prob, precision, recall] rows. iou and prob (presumably the IoU
# matching threshold and the score threshold) each run from 0.0 to 1.0 in
# steps of 0.125, giving 9 x 9 = 81 rows, ordered with iou as the outer
# (slow-varying) coordinate and prob as the inner one.
# Rows where either threshold is 1.0 store 0 for precision and recall,
# apparently as placeholders -- TODO confirm.
# NOTE(review): the script below in this file only reads ssd_res_details_dict;
# this table is not plotted anywhere visible here.
dssd_res_details_dict = dict(
    plane=[[0.0, 0.0, 0.021824982889184136, 1.0],
           [0.0, 0.125, 0.36231619880908983, 1.0],
           [0.0, 0.25, 0.764569664222308, 0.9966459835653194],
           [0.0, 0.375, 0.8860816944024206, 0.9822237128961931],
           [0.0, 0.5, 0.9371957156767283, 0.9689597315436241],
           [0.0, 0.625, 0.9729269538566321, 0.9595298068849706],
           [0.0, 0.75, 0.9931630082763584, 0.9283551967709385],
           [0.0, 0.875, 0.9997793468667255, 0.7630515325025261],
           [0.0, 1.0, 0, 0],
           [0.125, 0.0, 0.021799362416230202, 0.9988260942478618],
           [0.125, 0.125, 0.3545388261028071, 0.9785342948180447],
           [0.125, 0.25, 0.7439855911488485, 0.9698138520878752],
           [0.125, 0.375, 0.8748865355521936, 0.9698138520878752],
           [0.125, 0.5, 0.9307043167802661, 0.962248322147651],
           [0.125, 0.625, 0.962710710028946, 0.9494542401343409],
           [0.125, 0.75, 0.9838071248650594, 0.9196098217288934],
           [0.125, 0.875, 0.9964695498676082, 0.7605254294375211],
           [0.125, 1.0, 0, 0],
           [0.25, 0.0, 0.021781062078405966, 0.9979875901391917],
           [0.25, 0.125, 0.3545388261028071, 0.9785342948180447],
           [0.25, 0.25, 0.7439855911488485, 0.9698138520878752],
           [0.25, 0.375, 0.8748865355521936, 0.9698138520878752],
           [0.25, 0.5, 0.9307043167802661, 0.962248322147651],
           [0.25, 0.625, 0.9621998978375618, 0.9489504617968094],
           [0.25, 0.75, 0.9838071248650594, 0.9196098217288934],
           [0.25, 0.875, 0.9958075904677847, 0.76002020882452],
           [0.25, 1.0, 0, 0],
           [0.375, 0.0, 0.02168590032171994, 0.993627368774107],
           [0.375, 0.125, 0.3545388261028071, 0.9785342948180447],
           [0.375, 0.25, 0.7439855911488485, 0.9698138520878752],
           [0.375, 0.375, 0.8748865355521936, 0.9698138520878752],
           [0.375, 0.5, 0.9307043167802661, 0.962248322147651],
           [0.375, 0.625, 0.9621998978375618, 0.9489504617968094],
           [0.375, 0.75, 0.9838071248650594, 0.9196098217288934],
           [0.375, 0.875, 0.9958075904677847, 0.76002020882452],
           [0.375, 1.0, 0, 0],
           [0.5, 0.0, 0.02100512775465835, 0.9624350159315781],
           [0.5, 0.125, 0.3456677603597035, 0.9540499748448767],
           [0.5, 0.25, 0.729576739997427, 0.9510313600536643],
           [0.5, 0.375, 0.8579425113464448, 0.9510313600536643],
           [0.5, 0.5, 0.9131775397598182, 0.9441275167785235],
           [0.5, 0.625, 0.9450025540609569, 0.9319899244332494],
           [0.5, 0.75, 0.9688736955739474, 0.9056508577194753],
           [0.5, 0.875, 0.9911738746690203, 0.756483664533513],
           [0.5, 1.0, 0, 0],
           [0.625, 0.0, 0.018314978094495626, 0.8391749119570686],
           [0.625, 0.125, 0.3033175355450237, 0.8371625020962603],
           [0.625, 0.25, 0.6422230798919336, 0.8371625020962603],
           [0.625, 0.375, 0.7552193645990923, 0.8371625020962603],
           [0.625, 0.5, 0.8044466082440765, 0.8317114093959731],
           [0.625, 0.625, 0.8329644134173335, 0.8214945424013435],
           [0.625, 0.75, 0.8573227779776899, 0.8013790783720148],
           [0.625, 0.875, 0.8801853486319505, 0.6717750084203435],
           [0.625, 1.0, 0, 0],
           [0.75, 0.0, 0.005072853644878285, 0.23243333892336074],
           [0.75, 0.125, 0.08397132093814558, 0.23176253563642463],
           [0.75, 0.25, 0.17779493117200565, 0.23176253563642463],
           [0.75, 0.375, 0.2090771558245083, 0.23176253563642463],
           [0.75, 0.5, 0.22411554690035704, 0.23171140939597315],
           [0.75, 0.625, 0.23395198365400988, 0.23073047858942067],
           [0.75, 0.75, 0.23893486865779057, 0.22334342415068953],
           [0.75, 0.875, 0.22484554280670785, 0.1716066015493432],
           [0.75, 1.0, 0, 0],
           [0.875, 0.0, 6.22211486024032e-05, 0.0028509139694784503],
           [0.875, 0.125, 0.0010329323125531657, 0.0028509139694784503],
           [0.875, 0.25, 0.0021870577640550623, 0.0028509139694784503],
           [0.875, 0.375, 0.002571860816944024, 0.0028509139694784503],
           [0.875, 0.5, 0.00275884453099643, 0.0028523489932885905],
           [0.875, 0.625, 0.0028946024178443724, 0.0028547439126784214],
           [0.875, 0.75, 0.0030586541921554518, 0.002859064917591658],
           [0.875, 0.875, 0.0037511032656663726, 0.0028629168070057258],
           [0.875, 1.0, 0, 0],
           [1.0, 0.0, 0, 0],
           [1.0, 0.125, 0, 0],
           [1.0, 0.25, 0, 0],
           [1.0, 0.375, 0, 0],
           [1.0, 0.5, 0, 0],
           [1.0, 0.625, 0, 0],
           [1.0, 0.75, 0, 0],
           [1.0, 0.875, 0, 0],
           [1.0, 1.0, 0, 0]])


# Each row of a ``*_res_details_dict["plane"]`` table is
# [iou, prob, precision, recall].
# Thresholds whose result row is printed by the script entry point below.
iou_limit, prob_limit = 0.4, 0.4


def find_nearest_rows(rows, iou, prob):
    """Return the row(s) whose (iou, prob) point is closest to the request.

    The result tables in this file are sampled on a 0.125-step grid, so an
    exact float comparison against a threshold such as 0.4 never matches any
    row. This function snaps to the nearest sampled point instead.

    Args:
        rows: sequence of [iou, prob, precision, recall] rows.
        iou: requested IoU threshold.
        prob: requested probability threshold.

    Returns:
        List of rows (as lists of floats) at the closest grid point. There is
        more than one row only when several points are equally close. The list
        is empty when ``rows`` is empty.
    """
    table = np.asarray(rows, dtype=float)
    if table.size == 0:
        return []
    dist = np.hypot(table[:, 0] - iou, table[:, 1] - prob)
    # isclose rather than == so float round-off cannot hide a tied point.
    return table[np.isclose(dist, dist.min())].tolist()


def metric_surface(rows, column=2):
    """Arrange one metric column of a flat result table as a 2-D surface.

    The (iou, prob) grid is derived from the data instead of being hard-coded
    as 9 x 9. Each value is placed by its own coordinates, so row order does
    not matter.

    Args:
        rows: sequence of [iou, prob, precision, recall] rows.
        column: index of the metric to arrange (2 = precision, 3 = recall).

    Returns:
        (X, Y, Z) arrays of shape (n_prob, n_iou), laid out like
        ``np.meshgrid(ious, probs)``. X holds IoU values, Y holds probability
        values and Z holds the metric. Grid points with no row are NaN.
    """
    table = np.asarray(rows, dtype=float)
    ious = np.unique(table[:, 0])
    probs = np.unique(table[:, 1])
    Z = np.full((probs.size, ious.size), np.nan)
    # Row index = prob position, column index = iou position, matching the
    # meshgrid layout below (X varies along columns, Y along rows).
    Z[np.searchsorted(probs, table[:, 1]),
      np.searchsorted(ious, table[:, 0])] = table[:, column]
    X, Y = np.meshgrid(ious, probs)
    return X, Y, Z


def plot_metric_surface(rows, column=2, zlabel="Precision",
                        out_path="percision.png"):
    """Draw a 3-D surface of one metric over the (iou, prob) grid and save it.

    Args:
        rows: sequence of [iou, prob, precision, recall] rows.
        column: index of the metric to plot (2 = precision, 3 = recall).
        zlabel: label for the z axis.
        out_path: PNG path the figure is written to. The default keeps the
            historical (misspelled) file name so existing consumers find it.

    Returns:
        The matplotlib Figure that was drawn.
    """
    X, Y, Z = metric_surface(rows, column)
    fig = plt.figure(figsize=(7, 5))
    # ``Axes3D(fig)`` no longer attaches itself to the figure on matplotlib
    # >= 3.6, which yields an empty plot. add_subplot with the "3d"
    # projection works across versions; the module-level Axes3D import still
    # registers that projection on old matplotlib releases.
    ax = fig.add_subplot(111, projection="3d")
    surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='coolwarm',
                           linewidth=1, antialiased=True)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    ax.set_xlabel("IOU")
    ax.set_ylabel("PROB")
    ax.set_zlabel(zlabel)
    fig.savefig(out_path, format='png')
    return fig


if __name__ == "__main__":
    for sign_name, sign_detail in ssd_res_details_dict.items():
        print(sign_name)
        for item in find_nearest_rows(sign_detail, iou_limit, prob_limit):
            print(item)

    sub_sign_info = ssd_res_details_dict["plane"]
    print(np.array(sub_sign_info).shape)
    plot_metric_surface(sub_sign_info)
    plt.show()

# import pandas as pd
# import csv
# 
# with open("ssd.run.log", "r") as log_reader:
#     cur_idx = 0
#     res_list = list()
#     for line in log_reader:
#         cur_idx += 1
#         info = line.strip().split("]")[-1].strip()
#         if cur_idx % 3 == 1:
#             res_list.append(list())
#             iteration, loss_info = info.split(",")
#             iter_cnt = iteration.split(" ")[-1].strip()
#             loss_val = loss_info.split("=")[-1].strip()
#             res_list[-1].append(iter_cnt)
#             res_list[-1].append(loss_val)
#         if cur_idx % 3 == 2:
#             mbox_loss = info.split(":")[-1].strip().split("(")[0].strip().split(
#                 "=")[-1].strip()
#             res_list[-1].append(mbox_loss)
#         if cur_idx % 3 == 0:
#             lr_val = info.split(",")[-1].strip().split("=")[-1].strip()
#             res_list[-1].append(lr_val)
#     # print(res_list)

# sub_res_list = []
# for i in range(len(res_list)):
#     if i % 50 == 0:
#         sub_res_list.append(res_list[i])
#         print("{},{},{}".format(
#             res_list[i][0], res_list[i][1], res_list[i][2], res_list[i][3]))
# 
# with open("a.csv", "w") as h:
#     writer = csv.writer(h)
#     writer.writerows(sub_res_list)
# 
# pd_info = pd.DataFrame(
#     data=sub_res_list, columns=["iter", "loss", "mbox_loss"])
# pd_info_iter = pd_info.set_index("iter")
# pd_info_iter = pd_info_iter.apply(pd.to_numeric)
# pd_info_iter.plot()
# plt.show()

# import os
# import argparse
# 
# def dealwith_image_list(src_path, out_path, frame_nums=16):
#     if not os.path.exists(src_path):
#         raise IOError("Not %s path" % src_path)
#     res_list = list()
#     with open(src_path, "r") as reader:
#         for line in reader:
#             line = line.strip()
#             image_dir, image_label = line.split()
#             img_nums = len(os.listdir(image_dir))
#             for idx in range(1, img_nums, frame_nums):
#                 if idx + frame_nums > img_nums:
#                     idx = img_nums - frame_nums
#                 res_list.append([image_dir, idx, image_label])
#     with open(out_path, "w") as writer:
#         for item in res_list:
#             writer.write("{} {} {}\n".format(item[0], item[1], item[2]))
# 
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument("-i", "--input", help="input source path")
#     parser.add_argument("-o", "--output", help="output path")
#     parser.add_argument("-f", "--fps", default=16, help="frame number in each group")
#     args = parser.parse_args()
#     src_path = args.input
#     out_path = args.output
#     frame_nums = args.fps
#     dealwith_image_list(src_path, out_path, frame_nums)